"""
Phase 3: Sudden Stops & CA Reversals
=====================================
Tests whether demographic structure predicts current-account (CA) reversals
and sudden stops. Examines Z × KAOPEN and Z × NFA-creditor interactions,
opposing youth vs. aging channels, heterogeneity across demographic terciles,
and the speed of CA recovery after a reversal.
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
import warnings
warnings.filterwarnings('ignore')

# PROJECT_DIR is the directory two levels above this file; ROOT_DIR is the
# directory that contains PROJECT_DIR.
PROJECT_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = PROJECT_DIR.parent
# Prepend ROOT_DIR/multilateral/src to the import path so the sibling
# project's `model` module (which provides PanelGLS) can be imported.
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Input directory for the processed panel and output directory for the
# markdown regression tables.
DATA = PROJECT_DIR / "data" / "processed"
OUT_TABLES = PROJECT_DIR / "output" / "tables"
# NOTE: the output directory is created as an import-time side effect.
OUT_TABLES.mkdir(parents=True, exist_ok=True)


def run_regression(df, y_var, x_vars, label, feature_names=None, min_obs=50):
    """Fit a PanelGLS regression of ``y_var`` on ``x_vars`` and summarize it.

    Rows with a missing value in the outcome, any regressor, ``iso3`` or
    ``year`` are dropped before fitting. Coefficients, standard errors and
    p-values are printed and returned.

    Parameters
    ----------
    df : pandas.DataFrame
        Panel holding ``y_var``, ``x_vars``, ``iso3`` and ``year``.
    y_var : str
        Outcome column.
    x_vars : list of str
        Regressor columns, in the order they are passed to ``PanelGLS.fit``.
    label : str
        Model name used in printed output and stored under ``'model'``.
    feature_names : list of str, optional
        Display names for ``x_vars`` (same length). Defaults to ``x_vars``.
    min_obs : int, default 50
        The model is skipped when fewer complete rows than this remain.

    Returns
    -------
    dict or None
        ``model``, ``n_obs``, ``n_countries``, ``r_squared`` and ``rho``,
        plus ``<name>_coef`` / ``<name>_se`` / ``<name>_p`` for every
        regressor. Returns ``None`` (after printing why) when a required
        column is absent or there are too few complete observations.

    Raises
    ------
    ValueError
        If ``feature_names`` is given with a different length than ``x_vars``.
    """
    names = list(feature_names) if feature_names else list(x_vars)
    if len(names) != len(x_vars):
        # A mismatch would otherwise mislabel coefficients or raise IndexError.
        raise ValueError(
            f"{label}: {len(names)} feature_names given for "
            f"{len(x_vars)} x_vars")

    cols = [y_var] + list(x_vars) + ['iso3', 'year']
    # Skip gracefully like main() does for its optional columns, rather
    # than aborting the whole run with a KeyError.
    missing = [c for c in cols if c not in df.columns]
    if missing:
        print(f"  {label}: missing columns {missing}, skipping")
        return None

    # Complete-case sample.
    sub = df[cols].dropna()
    if len(sub) < min_obs:
        print(f"  {label}: insufficient obs ({len(sub)}), skipping")
        return None

    gls = PanelGLS()
    gls.fit(sub[y_var].values, sub[list(x_vars)].values,
            sub['iso3'].values, sub['year'].values)

    result = {
        'model': label,
        'n_obs': gls.n_obs,
        'n_countries': gls.n_countries,
        'r_squared': gls.r_squared,
        'rho': gls.rho,
    }
    # NOTE(review): assumes gls.beta / se / pvalues are ordered like the
    # columns of X with no leading intercept term. Confirm against PanelGLS.
    for i, name in enumerate(names):
        result[f'{name}_coef'] = gls.beta[i]
        result[f'{name}_se'] = gls.se[i]
        result[f'{name}_p'] = gls.pvalues[i]

    print(f"\n  {label} (N={gls.n_obs}, R²={gls.r_squared:.4f})")
    for i, name in enumerate(names):
        sig = stars(gls.pvalues[i])
        print(f"    {name:30s} {gls.beta[i]:8.4f} ({gls.se[i]:.4f}) {sig}")

    return result


def stars(p):
    """Map a p-value to its conventional significance marker.

    Returns '***' for p < 0.01, '**' for p < 0.05, '*' for p < 0.1 and an
    empty string otherwise (including NaN, which fails every comparison).
    """
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
        if p < cutoff:
            return marker
    return ''


def fmt(val, se, p):
    """Format a coefficient and its standard error for a results table.

    Returns a ``(coefficient_text, se_text)`` pair: the coefficient to four
    decimals followed by its significance stars, and the standard error to
    four decimals wrapped in parentheses.
    """
    coef_text = "{:.4f}{}".format(val, stars(p))
    se_text = "({:.4f})".format(se)
    return coef_text, se_text


def write_table(results, filename, title):
    """Write regression results side by side as a markdown table.

    Each entry in ``results`` becomes one column. Each regressor gets a
    coefficient row (with significance stars) and a standard-error row.
    Models that lack a regressor leave the cell blank. Rows for N, R² and
    country counts follow.

    Parameters
    ----------
    results : list of dict
        Model summaries as returned by ``run_regression``. Nothing is
        written when the list is empty.
    filename : str
        Name of the file created under ``OUT_TABLES``.
    title : str
        Markdown heading placed above the table.
    """
    if not results:
        return

    suffix = '_coef'
    # Regressors across all models, in first-seen order. Only the trailing
    # suffix is stripped, so a name that itself contains "_coef" survives.
    all_vars = list(dict.fromkeys(
        key[:-len(suffix)]
        for r in results for key in r if key.endswith(suffix)))

    n_models = len(results)
    lines = [f"# {title}\n"]
    lines.append("| Variable | " + " | ".join(r['model'] for r in results) + " |")
    lines.append("|:---|" + "|".join(["---:"] * n_models) + "|")

    for var in all_vars:
        coef_row = f"| {var} |"
        se_row = "| |"
        for r in results:
            if f'{var}_coef' in r:
                c, s = fmt(r[f'{var}_coef'], r[f'{var}_se'], r[f'{var}_p'])
                coef_row += f" {c} |"
                se_row += f" {s} |"
            else:
                coef_row += " |"
                se_row += " |"
        lines.append(coef_row)
        lines.append(se_row)

    # Blank spacer row between coefficients and fit statistics. A second
    # ":---" row is not a delimiter in GFM; it would render as literal text.
    lines.append("| |" + " |" * n_models)
    lines.append("| N |" + "".join(f" {r['n_obs']} |" for r in results))
    lines.append("| R² |" + "".join(f" {r['r_squared']:.4f} |" for r in results))
    lines.append("| Countries |" + "".join(f" {r['n_countries']} |" for r in results))

    lines.append("\n*Panel GLS with country and year fixed effects. "
                 "Standard errors in parentheses.*")
    lines.append("*\\*p<0.1, \\*\\*p<0.05, \\*\\*\\*p<0.01*")

    path = OUT_TABLES / filename
    # Explicit UTF-8: the table contains non-ASCII text such as "R²" and "×".
    path.write_text('\n'.join(lines) + '\n', encoding='utf-8')
    print(f"\n  Saved: {path}")


def _banner(title):
    """Print ``title`` between two 50-character '=' rules, after a blank line."""
    print("\n" + "=" * 50)
    print(title)
    print("=" * 50)


def _lag_by_year(df, col, k):
    """Return ``col`` from calendar year ``year - k`` of the same ``iso3``.

    Matches rows on (iso3, year - k) instead of shifting by row position, so
    a gap in a country's years yields NaN rather than silently pairing the
    row with an older observation. The result is a NumPy array aligned with
    ``df``'s row order. If (iso3, year) is duplicated, the first row's value
    is used.
    """
    src = df[['iso3', 'year', col]].drop_duplicates(['iso3', 'year'])
    src = src.assign(year=src['year'] + k)
    # A left merge against unique right keys keeps exactly one output row per
    # left row, in left order.
    aligned = df[['iso3', 'year']].merge(src, on=['iso3', 'year'], how='left')
    return aligned[col].to_numpy()


def main():
    """Run the Phase 3 regressions and write the markdown tables.

    Reads ``crises_panel.csv`` from ``DATA`` and writes, under ``OUT_TABLES``:
    sudden_stops.md (Z, Z×KAOPEN and Z×NFA-creditor models), 
    reversal_channels.md (youth vs. old-age dependency, 5pp reversals),
    reversal_terciles.md (Z models within each ``demo_tercile``) and
    recovery_speed.md (``d_ca_gdp`` one and two years after a reversal).
    The KAOPEN interaction columns, ``nfa_positive`` and ``demo_tercile``
    are optional; their models are reduced or skipped when absent.
    """
    print("=" * 70)
    print("PHASE 3: SUDDEN STOPS & CA REVERSALS")
    print("=" * 70)

    df = pd.read_csv(DATA / "crises_panel.csv")
    print(f"Panel: {len(df)} obs, {df['iso3'].nunique()} countries")
    print(f"CA reversals (≥3pp): {df['ca_reversal'].sum():.0f}")
    print(f"Sudden stops: {df['sudden_stop'].sum():.0f}")

    z_vars = ['Z_1', 'Z_2', 'Z_3']
    controls = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth', 'kaopen']

    # ── Table 1: CA Reversals & Sudden Stops ──
    _banner("TABLE 1: CA REVERSALS & SUDDEN STOPS")
    results = []

    # M1: Z → CA reversal
    print("\n--- M1: Z → ca_reversal ---")
    r = run_regression(df, 'ca_reversal', z_vars + controls, 'M1: Reversal')
    if r is not None:
        results.append(r)

    # M2: Z → sudden stop
    print("\n--- M2: Z → sudden_stop ---")
    r = run_regression(df, 'sudden_stop', z_vars + controls, 'M2: Sudden Stop')
    if r is not None:
        results.append(r)

    # M3: Z × KAOPEN → reversal. Interaction columns missing from the panel
    # are dropped; report them so the model is not mistaken for the full spec.
    print("\n--- M3: Z × KAOPEN → ca_reversal ---")
    interact_vars = z_vars + [f'{z}_x_kaopen' for z in z_vars]
    available = [v for v in interact_vars if v in df.columns]
    omitted = [v for v in interact_vars if v not in df.columns]
    if omitted:
        print(f"  M3: omitted (not in panel): {omitted}")
    r = run_regression(df, 'ca_reversal', available + controls, 'M3: Z×KAOPEN')
    if r is not None:
        results.append(r)

    # M4: Z × NFA creditor → reversal
    print("\n--- M4: Z × NFA creditor → ca_reversal ---")
    if 'nfa_positive' in df.columns:
        nfa_terms = []
        for z in z_vars:
            term = f'{z}_x_nfa_pos'
            df[term] = df[z] * df['nfa_positive']
            nfa_terms.append(term)
        r = run_regression(df, 'ca_reversal', z_vars + nfa_terms + controls,
                           'M4: Z×NFA')
        if r is not None:
            results.append(r)

    write_table(results, "sudden_stops.md",
                "Sudden Stops & CA Reversals")

    # ── Table 2: Opposing Channels (Youth vs. Aging) ──
    _banner("TABLE 2: OPPOSING CHANNELS — YOUTH VS. AGING")
    results_channels = []

    # M5: youth_dep + old_dep → reversal
    print("\n--- M5: youth_dep + old_dep → ca_reversal ---")
    r = run_regression(df, 'ca_reversal', ['youth_dep', 'old_dep'] + controls,
                       'M5: Youth + Aging')
    if r is not None:
        results_channels.append(r)

    # M6: youth_dep only
    print("\n--- M6: youth_dep → ca_reversal ---")
    r = run_regression(df, 'ca_reversal', ['youth_dep'] + controls,
                       'M6: Youth only')
    if r is not None:
        results_channels.append(r)

    # M7: old_dep only
    print("\n--- M7: old_dep → ca_reversal ---")
    r = run_regression(df, 'ca_reversal', ['old_dep'] + controls,
                       'M7: Aging only')
    if r is not None:
        results_channels.append(r)

    # M8: Stricter 5pp reversal
    print("\n--- M8: Z → ca_reversal_5pp ---")
    r = run_regression(df, 'ca_reversal_5pp', z_vars + controls,
                       'M8: 5pp Reversal')
    if r is not None:
        results_channels.append(r)

    write_table(results_channels, "reversal_channels.md",
                "CA Reversal Channels: Youth vs. Aging")

    # ── Table 3: By Demographic Tercile ──
    _banner("TABLE 3: BY DEMOGRAPHIC TERCILE")
    results_tercile = []

    if 'demo_tercile' in df.columns:
        for tercile in ['early', 'mid', 'late']:
            sub = df[df['demo_tercile'] == tercile].copy()
            if len(sub) < 100:
                print(f"  {tercile}: insufficient obs ({len(sub)}), skipping")
                continue

            print(f"\n--- {tercile} demographic tercile ---")
            r = run_regression(sub, 'ca_reversal', z_vars + controls,
                               f'{tercile.capitalize()} Trans.')
            if r is not None:
                results_tercile.append(r)

    if results_tercile:
        write_table(results_tercile, "reversal_terciles.md",
                    "CA Reversals by Demographic Transition Stage")

    # ── Table 4: Post-Reversal Recovery Speed ──
    _banner("TABLE 4: POST-REVERSAL RECOVERY")
    results_recovery = []

    # Recovery k years after a reversal: keep rows whose country had a
    # reversal in year t-k, and regress that row's d_ca_gdp on demographics.
    # Lags are matched on calendar year, not row position, so a gap in a
    # country's series cannot pair a row with the wrong year.
    df = df.sort_values(['iso3', 'year'])
    recovery_controls = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'kaopen']
    for k in (1, 2):
        df[f'reversal_lag{k}'] = _lag_by_year(df, 'ca_reversal', k)

    for k in (1, 2):
        post_rev = df[df[f'reversal_lag{k}'] == 1].copy()
        if len(post_rev) > 50:
            print(f"\n--- Post-reversal recovery (t+{k}), N={len(post_rev)} ---")
            r = run_regression(post_rev, 'd_ca_gdp', z_vars + recovery_controls,
                               f'Recovery t+{k}')
            if r is not None:
                results_recovery.append(r)

    if results_recovery:
        write_table(results_recovery, "recovery_speed.md",
                    "Post-Reversal Recovery Speed")

    print("\n" + "=" * 70)
    print("PHASE 3 COMPLETE")
    print("=" * 70)


if __name__ == '__main__':
    main()
